Skip to content

feat: 7-class tissue classifier on Virchow2 embeddings#13

Open
vojtech-cifka wants to merge 2 commits into
mainfrom
feature/tissue-linear-virchow2
Open

feat: 7-class tissue classifier on Virchow2 embeddings#13
vojtech-cifka wants to merge 2 commits into
mainfrom
feature/tissue-linear-virchow2

Conversation

@vojtech-cifka

@vojtech-cifka vojtech-cifka commented Jun 4, 2026

Copy link
Copy Markdown

Deploys a 7-class tissue classifier as a linear head over the Virchow2 foundation model. Per tile: apply Virchow2's image transform, fetch the ViT token sequence from the deployed virchow2 service via a Ray Serve handle, pool tokens (class token + mean of patch tokens) into a 2560-d embedding, run the ONNX linear head, and emit a 7-channel softmax probability map for HeatmapBuilder. The hard class map is recoverable via argmax over channels at full resolution.

The ONNX linear head is exported from the Virchow2 + LBFGS final linear classifier (MLflow run 0e2230c722134ce0985e09a18ccadf75, artifacts/onnx/linear_head.onnx).

Files:

  • models/tissue_linear.py: Serve deployment. torch/PIL/timm imports are lazy (the head node builds the app graph without them; the replica runs on GPU workers that carry them). ONNX runs on CPUExecutionProvider.
  • helm/rayservice/applications/tissue-linear.yaml: app definition (num_gpus: 1 to land on the mig20 GPU workers for torch/timm).
  • helm/rayservice/values.yaml: register tissue-linear.

Validated on the full WSI 07 Leiomyosarkom.svs via HeatmapBuilder, producing a (52224, 36864, 7) BigTIFF; argmax over channels yields Other (neoplastic) and Connective-Tissue (stroma) dominant, consistent with a leiomyosarcoma. You can check the resulting segmentation map.

Summary by CodeRabbit

  • New Features
    • Added a new tissue-linear inference app exposed at /tissue-linear.
    • Added a POST / endpoint that ingests LZ4-compressed tile payloads and returns per-tile class probability maps.
    • Enabled batched GPU-backed inference with autoscaling (0–2 replicas) and tuned request limits for consistent throughput.
    • Uses a selectable foundation model and an MLflow-tracked ONNX artifact to drive inference.

@vojtech-cifka vojtech-cifka requested review from a team, Jurgee, ejdam87 and matejpekar June 4, 2026 17:06
@coderabbitai

coderabbitai Bot commented Jun 4, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

📝 Walkthrough

Walkthrough

Adds a TissueLinear Ray Serve application and Helm registration plus a FastAPI-backed Ray deployment that LZ4-decompresses tiles, obtains Virchow2 embeddings remotely, runs an ONNX linear head in batches, and returns per-class probability maps.

Changes

TissueLinear Tissue Classification Service

Layer / File(s) Summary
Helm Configuration & Application Registration
helm/rayservice/applications/tissue-linear.yaml, helm/rayservice/values.yaml
Registers the tissue-linear RayService application with import_path: models.tissue_linear:app, route_prefix: /tissue-linear, runtime env, autoscaling (0–2), request limits, and model provider config.
Configuration Schema & Deployment Class Definition
models/tissue_linear.py
Adds Config TypedDict with tile sizing, batching, and model provider fields; creates module-level FastAPI instance and TissueLinear Ray Serve deployment class with __init__ storing LZ4 frame module.
Model Resolution & ONNX Session Setup
models/tissue_linear.py
Implements reconfigure() to build Virchow2 transform, resolve ONNX model file from provider target, create ONNX Runtime session (CPU provider), capture tensor metadata, and set Serve batching.
Configuration Endpoint
models/tissue_linear.py
Adds get_config() returning runtime subset (tile_size, output_tile_size, n_channels, mpp).
Tile Preprocessing & Remote Embedding
models/tissue_linear.py
Implements _prepare_tile_for_virchow2() to convert CHW→HWC→PIL→transform and _create_embedding() to call remote Virchow2 service for token embeddings, pool tokens, and return float32 embeddings.
Batched Prediction & ONNX Inference
models/tissue_linear.py
@serve.batch decorated predict() creates embeddings concurrently, stacks them, runs ONNX linear head, reshapes outputs to (n_classes,1,1) probability maps.
HTTP Ingress, LZ4 Handling & Module Binding
models/tissue_linear.py
root() handler decompresses LZ4 request body, reconstructs RGB tile, transposes to CHW, calls predict(), returns JSON-serializable probabilities; exports app = TissueLinear.bind().

Sequence Diagram

sequenceDiagram
  participant Client
  participant FastAPI_root
  participant TissueLinear_predict
  participant ThreadPool
  participant Virchow2Service
  participant ONNXRuntime
  Client->>FastAPI_root: POST /tissue-linear (LZ4 tile)
  FastAPI_root->>FastAPI_root: decompress & reshape to CHW
  FastAPI_root->>TissueLinear_predict: predict(tiles)
  TissueLinear_predict->>ThreadPool: _prepare_tile_for_virchow2(tile)
  ThreadPool-->>TissueLinear_predict: transformed tensor
  TissueLinear_predict->>Virchow2Service: get_app_handle (token embeddings)
  Virchow2Service-->>TissueLinear_predict: token embeddings
  TissueLinear_predict->>ONNXRuntime: run(linear head on batch)
  ONNXRuntime-->>TissueLinear_predict: logits
  TissueLinear_predict-->>FastAPI_root: probability maps
  FastAPI_root-->>Client: JSON response
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • RationAI/model-service#2: Implements the Virchow2 foundation-model deployment and preprocessing used by this TissueLinear service.

Suggested reviewers

  • Adames4
  • JakubPekar
  • ejdam87

Poem

🐰 I nibble bytes and stitch a flow,
Tiles unwarp and embers glow,
Virchow2 whispers tokens bright,
ONNX hums the classed light,
Ray batches hop into the night.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately and specifically describes the main change: deploying a 7-class tissue classifier on Virchow2 embeddings, which is the core objective of the PR.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feature/tissue-linear-virchow2

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new Ray Serve application, tissue-linear, which implements a 7-class tissue classifier using a linear head over Virchow2 embeddings. The feedback focuses on several key performance and resource optimizations: releasing the unused GPU resource in the Helm configuration, constructing the image transform directly instead of instantiating the heavy Virchow2 model to save memory, eliminating redundant array transpositions between CHW and HWC formats, and offloading the CPU-bound ONNX inference to a thread pool to avoid blocking the async event loop.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread helm/rayservice/applications/tissue-linear.yaml
Comment thread models/tissue_linear.py Outdated
Comment thread models/tissue_linear.py
Comment thread models/tissue_linear.py Outdated
Comment thread models/tissue_linear.py
@vojtech-cifka vojtech-cifka changed the title Add tissue-linear: 7-class tissue classifier on Virchow2 embeddings feat: 7-class tissue classifier on Virchow2 embeddings Jun 4, 2026
@vojtech-cifka vojtech-cifka force-pushed the feature/tissue-linear-virchow2 branch from 26ae1d7 to 8aca5aa Compare June 4, 2026 17:12

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@helm/rayservice/applications/tissue-linear.yaml`:
- Around line 16-18: The ray_actor_options currently reserves a GPU (num_gpus:
1) even though this replica is CPU-only; remove or set num_gpus to 0 in the
ray_actor_options block to avoid claiming GPU MIG slots, and instead use a node
label/custom resource or nodeAffinity for placement if you need GPU-hosted
nodes; note that models/tissue_linear.py explicitly uses CPUExecutionProvider
(lines ~106-109) and embeddings are fetched from the remote virchow2 app, so no
local CUDA is required.
- Line 7: The working_dir currently points to a moving ref
(https://github.com/RationAI/model-service/archive/refs/heads/main.zip); update
the working_dir value in the tissue-linear.yaml to an immutable archive URL (a
release tag or commit SHA zip, e.g.
https://github.com/RationAI/model-service/archive/<COMMIT_SHA>.zip) so the chart
pulls a fixed revision; change the value for the working_dir key and regenerate
any chart lock or documentation that records the pinned revision.

In `@models/tissue_linear.py`:
- Around line 179-186: In root(Request) (models/tissue_linear.py) validate the
decompressed payload length before reshaping: compute expected_size =
self.tile_size * self.tile_size * 3, attempt decompression but constrain or
check output size from self.lz4.decompress, and if the decompressed length !=
expected_size return a 400 (e.g., raise HTTPException(status_code=400) or return
a 400 Response) instead of reshaping; only call np.frombuffer(...).reshape(...)
when the length matches exactly to avoid oversized allocations and malformed
payloads.
- Around line 59-61: Validate the advertised output contract in reconfigure():
check that config["output_tile_size"] == 1 and config["n_channels"] ==
self._num_classes (use the instance attributes output_tile_size and n_channels
and the model property self._num_classes) and raise a clear ValueError if either
check fails so callers (and HeatmapBuilder) fail fast; add the same validation
where the config is parsed/assigned (the other block around the second
assignment of tile_size/output_tile_size/n_channels) to ensure consistency
across reconfigure() and any alternate config path.
- Line 138: Wrap the direct RPC await to
self.foundation_model.predict.remote(tile_tensor) in an async timeout (e.g.,
asyncio.wait_for) with a configurable timeout value and handle
asyncio.TimeoutError: on timeout, cancel the Ray object ref
(ray.cancel(virchow2_output, force=True)), log the timeout including the model
and tile context, and surface a clear error or fallback so the caller (methods
in models/tissue_linear.py that call foundation_model.predict.remote) can fail
fast instead of hanging; make the timeout value configurable via existing
config/constants and ensure the cancel + logging occurs before re-raising or
returning an error response.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6580bdcd-1f07-4b23-9f23-971331a6326b

📥 Commits

Reviewing files that changed from the base of the PR and between 3d37a6e and 26ae1d7.

📒 Files selected for processing (3)
  • helm/rayservice/applications/tissue-linear.yaml
  • helm/rayservice/values.yaml
  • models/tissue_linear.py

Comment thread helm/rayservice/applications/tissue-linear.yaml
Comment thread helm/rayservice/applications/tissue-linear.yaml
Comment thread models/tissue_linear.py
Comment thread models/tissue_linear.py
Comment thread models/tissue_linear.py Outdated
@vojtech-cifka vojtech-cifka force-pushed the feature/tissue-linear-virchow2 branch from 8aca5aa to 8be7cab Compare June 4, 2026 17:30

@matejpekar matejpekar left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check the comments on the other PR #12

@vojtech-cifka vojtech-cifka force-pushed the feature/tissue-linear-virchow2 branch from 8be7cab to 473042b Compare June 4, 2026 18:25

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/tissue_linear.py`:
- Around line 175-183: The current predict() call uses embeddings = await
asyncio.gather(*(self._create_embedding(tile) for tile in tiles)) which will
abort the whole batch if any _create_embedding(tile) raises; change to use
asyncio.gather(..., return_exceptions=True), iterate the results to detect
exceptions per index, log/record the failing tile index and exception, and
either skip that tile or substitute a default embedding (e.g., zeros of the same
shape) before creating batch = np.stack(...). Ensure you only pass the resulting
valid/substituted embeddings into session.run({self.input_name: batch}) and
preserve mapping from output logits to the original tile indices (so output_name
results align with non-failed tiles). Do not change _create_embedding()’s
no-timeout call to self.foundation_model.predict.remote(tile_tensor).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d8dbfa32-9ce9-457d-9186-77af54920e5d

📥 Commits

Reviewing files that changed from the base of the PR and between 26ae1d7 and 473042b.

📒 Files selected for processing (3)
  • helm/rayservice/applications/tissue-linear.yaml
  • helm/rayservice/values.yaml
  • models/tissue_linear.py
✅ Files skipped from review due to trivial changes (1)
  • helm/rayservice/values.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
  • helm/rayservice/applications/tissue-linear.yaml

Comment thread models/tissue_linear.py
@vojtech-cifka vojtech-cifka force-pushed the feature/tissue-linear-virchow2 branch 2 times, most recently from 5de47b1 to bacc3c1 Compare June 4, 2026 18:51
@vojtech-cifka vojtech-cifka requested a review from matejpekar June 4, 2026 19:26
@matejpekar matejpekar removed the request for review from ejdam87 June 4, 2026 20:51

@matejpekar matejpekar left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You haven't tackled the biggest issue

Comment thread models/tissue_linear.py Outdated
@vojtech-cifka vojtech-cifka force-pushed the feature/tissue-linear-virchow2 branch from bacc3c1 to b3f3ab4 Compare June 5, 2026 18:53

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@models/tissue_linear.py`:
- Around line 107-113: After creating the ONNX session in the TissueLinear
initializer (the block that sets self.session, self.input_name, self.output_name
and self._num_classes), validate the model's input dimension matches the
expected 2560 Virchow2 embedding width by reading
self.session.get_inputs()[0].shape[-1], converting to int, and raising a clear
exception (or failing fast) if it does not equal 2560; this prevents silent
initialization of a model that emits 7 classes but accepts the wrong input width
and avoids later runtime failures in session.run().
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9cddaba4-49fd-401f-ad08-3d18183753bd

📥 Commits

Reviewing files that changed from the base of the PR and between 473042b and b3f3ab4.

📒 Files selected for processing (3)
  • helm/rayservice/applications/tissue-linear.yaml
  • helm/rayservice/values.yaml
  • models/tissue_linear.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • helm/rayservice/values.yaml
  • helm/rayservice/applications/tissue-linear.yaml

Comment thread models/tissue_linear.py
Deploys a 7-class tissue classifier as a linear head over the Virchow2
foundation model. Per tile: apply Virchow2's image transform, fetch the
ViT token sequence from the deployed `virchow2` service via a Ray Serve
handle, pool tokens (class token + mean of patch tokens) into a 2560-d
embedding, run the ONNX linear head, and emit a 7-channel softmax
probability map for HeatmapBuilder. The hard class map is recoverable
via argmax over channels at full resolution.

The ONNX linear head is exported from the Virchow2 + LBFGS final linear
classifier (MLflow run 0e2230c722134ce0985e09a18ccadf75,
artifacts/onnx/linear_head.onnx).

Files:
- models/tissue_linear.py: Serve deployment. torch/PIL/timm imports are
  lazy (the head node builds the app graph without them; the replica
  runs on GPU workers that carry them). ONNX runs on CPUExecutionProvider.
- helm/rayservice/applications/tissue-linear.yaml: app definition
  (num_gpus: 1 to land on the mig20 GPU workers for torch/timm).
- helm/rayservice/values.yaml: register tissue-linear.

Validated on the full WSI 07 Leiomyosarkom.svs via HeatmapBuilder,
producing a (52224, 36864, 7) BigTIFF; argmax over channels yields
Other (neoplastic) and Connective-Tissue (stroma) dominant, consistent
with a leiomyosarcoma.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@vojtech-cifka vojtech-cifka force-pushed the feature/tissue-linear-virchow2 branch from b3f3ab4 to 64725d2 Compare June 5, 2026 19:30
@vojtech-cifka vojtech-cifka requested a review from matejpekar June 5, 2026 19:32
Comment thread models/tissue_linear.py Outdated
Comment thread models/tissue_linear.py
Comment thread models/tissue_linear.py Outdated
Comment thread helm/rayservice/applications/tissue-linear.yaml
Comment thread models/tissue_linear.py Outdated
- Remove unreachable .onnx rglob fallback; artifact_uri points
  directly at the file, so provider() returns it as-is
- Drop n_channels / output_tile_size / embedding-dim validation guards
  per review
- Shorten verbose comments

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (1)
models/tissue_linear.py (1)

153-159: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate decompressed payload size before reshaping.

The handler decompresses arbitrary request bodies and reshapes without verifying the expected size. A malformed payload triggers a 500 at reshape; a highly-compressible body can force a larger allocation than intended. Since the exact raw size is known (tile_size * tile_size * 3), validate before proceeding:

Suggested fix
     `@fastapi.post`("/")
     async def root(self, request: Request) -> list[Any]:
         data = await asyncio.to_thread(self.lz4.decompress, await request.body())
+        expected_size = self.tile_size * self.tile_size * 3
+        if len(data) != expected_size:
+            from fastapi import HTTPException
+            raise HTTPException(
+                status_code=400,
+                detail=f"Expected {expected_size} bytes after decompression, got {len(data)}",
+            )
 
         tile = np.frombuffer(data, dtype=np.uint8).reshape(
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/tissue_linear.py` around lines 153 - 159, The code decompresses
request body data and immediately reshapes it without validating the payload
size, which can cause 500 errors on malformed payloads or force unexpected
memory allocations. After decompressing the data via self.lz4.decompress,
calculate the expected byte size as self.tile_size * self.tile_size * 3 and
validate that the decompressed data length matches this expected size before
calling np.frombuffer and reshape. If the size validation fails, raise an
appropriate error (like ValueError) with a descriptive message indicating the
size mismatch.
🧹 Nitpick comments (1)
models/tissue_linear.py (1)

126-128: 💤 Low value

Consider documenting the Virchow2 token layout.

The magic number 5 in patch_tokens = virchow2_output[:, 5:] assumes Virchow2's specific token layout: [CLS, reg1, reg2, reg3, reg4, patches...]. A brief inline comment noting this (e.g., "skip class token + 4 register tokens") would help future maintainers understand the offset without consulting Virchow2 documentation.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@models/tissue_linear.py` around lines 126 - 128, Add an inline comment above
or on the line containing patch_tokens = virchow2_output[:, 5:] to document the
magic number 5, explaining that it skips Virchow2's class token (CLS) and 4
register tokens (reg1, reg2, reg3, reg4) to extract only the patch tokens. This
will help future maintainers understand the token layout offset without needing
to consult external Virchow2 documentation.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Duplicate comments:
In `@models/tissue_linear.py`:
- Around line 153-159: The code decompresses request body data and immediately
reshapes it without validating the payload size, which can cause 500 errors on
malformed payloads or force unexpected memory allocations. After decompressing
the data via self.lz4.decompress, calculate the expected byte size as
self.tile_size * self.tile_size * 3 and validate that the decompressed data
length matches this expected size before calling np.frombuffer and reshape. If
the size validation fails, raise an appropriate error (like ValueError) with a
descriptive message indicating the size mismatch.

---

Nitpick comments:
In `@models/tissue_linear.py`:
- Around line 126-128: Add an inline comment above or on the line containing
patch_tokens = virchow2_output[:, 5:] to document the magic number 5, explaining
that it skips Virchow2's class token (CLS) and 4 register tokens (reg1, reg2,
reg3, reg4) to extract only the patch tokens. This will help future maintainers
understand the token layout offset without needing to consult external Virchow2
documentation.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: db4c2582-4a43-4bfd-bd7a-ce3e648e45a9

📥 Commits

Reviewing files that changed from the base of the PR and between b3f3ab4 and 67e5e22.

📒 Files selected for processing (3)
  • helm/rayservice/applications/tissue-linear.yaml
  • helm/rayservice/values.yaml
  • models/tissue_linear.py
✅ Files skipped from review due to trivial changes (1)
  • helm/rayservice/values.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
  • helm/rayservice/applications/tissue-linear.yaml

@vojtech-cifka vojtech-cifka requested a review from Jurgee June 14, 2026 21:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants